{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.]]),\n",
       " array([2, 1, 0, 1, 0]),\n",
       " (800, 784),\n",
       " (800,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "#加载数据\n",
    "def load_data():\n",
    "    #这个数据集,只有数字0,1,2,3,一共800行784列\n",
    "    with open('mnist_cut.csv') as fr:\n",
    "        lines = fr.readlines()\n",
    "\n",
    "    x = np.empty((len(lines), 784), dtype=float)\n",
    "    y = np.empty(len(lines), dtype=int)\n",
    "\n",
    "    for i in range(len(lines)):\n",
    "        line = lines[i].strip().split(',')\n",
    "        x[i] = line[1:]\n",
    "        y[i] = line[0]\n",
    "\n",
    "    #归一化\n",
    "    x /= 255\n",
    "\n",
    "    return x, y\n",
    "\n",
    "\n",
    "# Load the data once; later cells read the globals x and y.\n",
    "x, y = load_data()\n",
    "# Preview the first five rows and labels, and show both array shapes.\n",
    "x[:5], y[:5], x.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 168\n",
      "1 225\n",
      "2 209\n",
      "3 198\n"
     ]
    }
   ],
   "source": [
    "# Show how many samples carry each of the four labels\n",
    "for i in range(4):\n",
    "    print(i, np.count_nonzero(y == i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# N samples, M features per sample\n",
    "N, M = x.shape\n",
    "\n",
    "# One-vs-one multi-class setup: the 4 classes pair up into 6 binary\n",
    "# classifiers, each keyed by the two digits it separates.\n",
    "params = ['01', '02', '03', '12', '13', '23']\n",
    "\n",
    "# Every classifier starts from uniform weights 1/M and a zero bias.\n",
    "ws = {pair: np.full(M, 1 / M) for pair in params}\n",
    "bs = {pair: 0.0 for pair in params}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5360144532148693"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prediction function: each call evaluates exactly one binary classifier\n",
    "def predict(params, x):\n",
    "    \"\"\"Return the sigmoid output of one pairwise classifier for sample x.\n",
    "\n",
    "    params is the classifier key (e.g. '01'); it selects the weight vector\n",
    "    in ws and the bias in bs. x is a single feature vector.\n",
    "    \"\"\"\n",
    "    weights = ws[params]\n",
    "    bias = bs[params]\n",
    "    logit = weights.dot(x) + bias\n",
    "    return 1 / (1 + np.exp(-logit))\n",
    "\n",
    "\n",
    "predict('01', x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-278.37911686298537"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#这个log函数是为了避免log0.\n",
    "def log(p):\n",
    "    if p == 0:\n",
    "        p = 1e-20\n",
    "    return np.log(p)\n",
    "\n",
    "# Loss function\n",
    "def get_loss(params):\n",
    "    \"\"\"Sum the log-likelihood of classifier `params` over its own samples.\n",
    "\n",
    "    Only samples whose label is one of the two digits in the key take\n",
    "    part. The first digit of the key is target 0, the second is target 1.\n",
    "    The value is never positive; closer to 0 means a better fit.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for i in range(N):\n",
    "        digit = str(y[i])\n",
    "        # Samples of other classes are meaningless for this pair: skip them\n",
    "        if digit not in params:\n",
    "            continue\n",
    "        p = predict(params, x[i])\n",
    "        # Target encoding: first class of the pair -> 0, the other -> 1\n",
    "        target = 0 if digit == params[0] else 1\n",
    "        total += target * log(p) + (1 - target) * log(1 - p)\n",
    "\n",
    "    return total\n",
    "\n",
    "\n",
    "get_loss('01')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.42421197,  0.68744269,  0.37266674,\n",
       "        0.02043476,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
       "        0.        ,  0.        ,  0.        , -0.10991191, -0.47420267,\n",
       "        0.14084069, -0.13980142, -1.05943653, -1.28035144, -2.15883399])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Brute-force (finite-difference) gradient\n",
    "# The notebook holds 6 binary classifiers, so the caller must say which\n",
    "# one the gradient is for\n",
    "def get_gradient(params):\n",
    "    \"\"\"Estimate the gradient of get_loss(params) by forward differences.\n",
    "\n",
    "    Each weight, and then the bias, of classifier `params` is nudged by a\n",
    "    small step; the change in get_loss divided by that step is the partial\n",
    "    derivative. Every parameter is put back to its saved value afterwards.\n",
    "\n",
    "    Returns:\n",
    "        (gradient_w, gradient_b): an array with M partial derivatives for\n",
    "        the weights, and the partial derivative for the bias.\n",
    "    \"\"\"\n",
    "    upsilon = 1e-2\n",
    "    weights = ws[params]\n",
    "\n",
    "    # The unperturbed loss is the same for every coordinate, so compute it\n",
    "    # once instead of once per weight.\n",
    "    base_loss = get_loss(params)\n",
    "\n",
    "    gradient_w = np.empty(M)\n",
    "    for i in range(M):\n",
    "        saved = weights[i]\n",
    "        weights[i] = saved + upsilon\n",
    "        gradient_w[i] = (get_loss(params) - base_loss) / upsilon\n",
    "        # Restore the saved value; adding and then subtracting upsilon is\n",
    "        # not guaranteed to round-trip exactly in floating point.\n",
    "        weights[i] = saved\n",
    "\n",
    "    saved_b = bs[params]\n",
    "    bs[params] = saved_b + upsilon\n",
    "    gradient_b = (get_loss(params) - base_loss) / upsilon\n",
    "    bs[params] = saved_b\n",
    "\n",
    "    return gradient_w, gradient_b\n",
    "\n",
    "\n",
    "get_gradient('01')[0][:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "01 0 -0.09721305308799165\n",
      "01 10 -0.07259197823765838\n",
      "01 20 -0.0584861844270012\n",
      "01 30 -0.04940968066272462\n",
      "01 40 -0.04311613792244203\n",
      "02 0 -57.20651278101825\n",
      "02 10 -93.89839611020031\n",
      "02 20 -28.356635022821543\n",
      "02 30 -6.679625478918654\n",
      "02 40 -1.5635731356544724\n",
      "03 0 -19.554040865591535\n",
      "03 10 -2.7450048366824467\n",
      "03 20 -1.2179360222932507\n",
      "03 30 -0.8372873206165021\n",
      "03 40 -0.6699565754595802\n",
      "12 0 -55.64806841835235\n",
      "12 10 -5.870600759079906\n",
      "12 20 -2.6972055519559426\n",
      "12 30 -1.868802480537476\n",
      "12 40 -1.4838556615987815\n",
      "13 0 -26.619788831876587\n",
      "13 10 -4.020637036792894\n",
      "13 20 -2.47888129190784\n",
      "13 30 -1.8803402761977672\n",
      "13 40 -1.5489130504073492\n",
      "23 0 -402.3800096483134\n",
      "23 10 -85.59650846519047\n",
      "23 20 -55.80121607782987\n",
      "23 30 -35.128441712381175\n",
      "23 40 -19.732747827585374\n"
     ]
    }
   ],
   "source": [
    "# Training: there are 6 classifiers, so there are 6 training runs.\n",
    "# get_loss is a log-likelihood, so parameters step along the gradient.\n",
    "for param in params:\n",
    "    for i in range(50):\n",
    "        g_w, g_b = get_gradient(param)\n",
    "        ws[param] += 1e-2 * g_w\n",
    "        bs[param] += 1e-1 * g_b\n",
    "\n",
    "        # Report progress on every 10th iteration\n",
    "        if not i % 10:\n",
    "            print(param, i, get_loss(param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 1)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The final label is decided by a vote among the 6 classifiers. Ties\n",
    "# could be broken by classifier confidence; that is not done here.\n",
    "def vote(x):\n",
    "    \"\"\"Return the digit (0-3) that wins the most pairwise classifiers for x.\n",
    "\n",
    "    Each classifier votes for the second digit of its key when its output\n",
    "    is above 0.5 and for the first digit otherwise.\n",
    "    \"\"\"\n",
    "    tally = np.zeros(4)\n",
    "    for pair in params:\n",
    "        winner = pair[1] if predict(pair, x) > 0.5 else pair[0]\n",
    "        tally[int(winner)] += 1\n",
    "    return tally.argmax()\n",
    "\n",
    "\n",
    "vote(x[1]), y[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.99625\n"
     ]
    }
   ],
   "source": [
    "# Test: accuracy of the voted prediction, measured on the same x and y\n",
    "# that the classifiers were trained on\n",
    "correct = sum(1 for i in range(N) if vote(x[i]) == y[i])\n",
    "\n",
    "print(correct / N)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
